{
  "cells": [
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# code by Tae Hwan Jung @graykode\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "\n",
        "def make_batch():\n",
        "    input_batch, target_batch = [], []\n",
        "\n",
        "    for seq in seq_data:\n",
        "        input = [word_dict[n] for n in seq[:-1]] # 'm', 'a' , 'k' is input\n",
        "        target = word_dict[seq[-1]] # 'e' is target\n",
        "        input_batch.append(np.eye(n_class)[input])\n",
        "        target_batch.append(target)\n",
        "\n",
        "    return input_batch, target_batch\n",
        "\n",
        "class TextLSTM(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(TextLSTM, self).__init__()\n",
        "\n",
        "        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden)\n",
        "        self.W = nn.Linear(n_hidden, n_class, bias=False)\n",
        "        self.b = nn.Parameter(torch.ones([n_class]))\n",
        "\n",
        "    def forward(self, X):\n",
        "        input = X.transpose(0, 1)  # X : [n_step, batch_size, n_class]\n",
        "\n",
        "        hidden_state = torch.zeros(1, len(X), n_hidden)  # [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
        "        cell_state = torch.zeros(1, len(X), n_hidden)     # [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
        "\n",
        "        outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))\n",
        "        outputs = outputs[-1]  # [batch_size, n_hidden]\n",
        "        model = self.W(outputs) + self.b  # model : [batch_size, n_class]\n",
        "        return model\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    n_step = 3 # number of cells(= number of Step)\n",
        "    n_hidden = 128 # number of hidden units in one cell\n",
        "\n",
        "    char_arr = [c for c in 'abcdefghijklmnopqrstuvwxyz']\n",
        "    word_dict = {n: i for i, n in enumerate(char_arr)}\n",
        "    number_dict = {i: w for i, w in enumerate(char_arr)}\n",
        "    n_class = len(word_dict)  # number of class(=number of vocab)\n",
        "\n",
        "    seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']\n",
        "\n",
        "    model = TextLSTM()\n",
        "\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
        "\n",
        "    input_batch, target_batch = make_batch()\n",
        "    input_batch = torch.FloatTensor(input_batch)\n",
        "    target_batch = torch.LongTensor(target_batch)\n",
        "\n",
        "    # Training\n",
        "    for epoch in range(1000):\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        output = model(input_batch)\n",
        "        loss = criterion(output, target_batch)\n",
        "        if (epoch + 1) % 100 == 0:\n",
        "            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "    inputs = [sen[:3] for sen in seq_data]\n",
        "\n",
        "    predict = model(input_batch).data.max(1, keepdim=True)[1]\n",
        "    print(inputs, '->', [number_dict[n.item()] for n in predict.squeeze()])"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "anaconda-cloud": {},
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.1"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}